package resolvers

import (
	"context"
	"math"
	"strings"
	"time"

	"github.com/graph-gophers/graphql-go"
	"github.com/pkg/errors"
	"github.com/stackrox/rox/central/graphql/resolvers/inputtypes"
	"github.com/stackrox/rox/central/graphql/resolvers/loaders"
	"github.com/stackrox/rox/central/metrics"
	"github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/common"
	v1 "github.com/stackrox/rox/generated/api/v1"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/features"
	pkgMetrics "github.com/stackrox/rox/pkg/metrics"
	"github.com/stackrox/rox/pkg/protocompat"
	"github.com/stackrox/rox/pkg/sac"
	"github.com/stackrox/rox/pkg/search"
	"github.com/stackrox/rox/pkg/search/paginated"
	"github.com/stackrox/rox/pkg/search/scoped"
	"github.com/stackrox/rox/pkg/utils"
)

// init registers the GraphQL schema pieces for vulnerability requests with the package's schema
// builder: two input types, the DeferralRequest and VulnerabilityRequest object types, the
// workflow mutations and the lookup queries. Every Add* result is handed to utils.Must, which
// presumably panics at startup if any registration fails — TODO confirm.
//
// Several referenced types (SlimUser, RequestComment, VulnerabilityRequest_Scope,
// FalsePositiveRequest, VulnerabilityRequest_CVEs, VulnReqScope, FalsePositiveVulnRequest,
// Deployment, Image, Pagination, Time) are not declared in this block; presumably they are
// registered elsewhere in this package.
func init() {
	schema := getBuilder()
	utils.Must(
		// Expiry input taken by the updateVulnerabilityRequest mutation.
		schema.AddInput("VulnReqExpiry", []string{
			"expiresOn: Time",
			"expiresWhenFixed: Boolean",
		}),
		// Request input taken by the deferVulnerability mutation.
		schema.AddInput("DeferVulnRequest", []string{
			"comment: String",
			"cve: String",
			"expiresOn: Time",
			"expiresWhenFixed: Boolean",
			"scope: VulnReqScope",
		}),

		// Object type served by DeferralRequestResolver.
		schema.AddType("DeferralRequest", []string{
			"expiresOn: Time",
			"expiresWhenFixed: Boolean!",
		}),
		// Object type served by VulnerabilityRequestResolver.
		schema.AddType("VulnerabilityRequest", []string{
			"id: ID!",
			"targetState: String!",
			"status: String!",
			"expired: Boolean!",
			"requestor: SlimUser",
			"approvers: [SlimUser!]!",
			"createdAt: Time",
			// NOTE(review): capitalized unlike its sibling fields; this is the name clients
			// query, so renaming it would be a breaking schema change.
			"LastUpdated: Time",
			"comments: [RequestComment!]!",
			"scope: VulnerabilityRequest_Scope",
			"deferralReq: DeferralRequest",
			"falsePositiveReq: FalsePositiveRequest",
			"updatedDeferralReq: DeferralRequest",
			"cves: VulnerabilityRequest_CVEs",

			//// Derived fields

			// These are computed per request via vulnReqQueryMgr rather than read from the
			// stored request (see the DeploymentCount/ImageCount/Deployments/Images resolvers).
			"deploymentCount(query: String): Int!",
			"imageCount(query: String): Int!",

			"deployments(query: String, pagination: Pagination): [Deployment!]!",
			"images(query: String, pagination: Pagination): [Image!]!",
		}),

		// Workflow mutations; each is served by the Resolver method of the same name.
		schema.AddMutation("deferVulnerability(request: DeferVulnRequest!): VulnerabilityRequest!"),
		schema.AddMutation("markVulnerabilityFalsePositive(request: FalsePositiveVulnRequest!): VulnerabilityRequest!"),
		schema.AddMutation("approveVulnerabilityRequest(requestID: ID!, comment: String!): VulnerabilityRequest!"),
		schema.AddMutation("denyVulnerabilityRequest(requestID: ID!, comment: String!): VulnerabilityRequest!"),
		schema.AddMutation("updateVulnerabilityRequest(requestID: ID!, comment: String!, expiry: VulnReqExpiry!): VulnerabilityRequest!"),
		schema.AddMutation("undoVulnerabilityRequest(requestID: ID!): VulnerabilityRequest!"),
		schema.AddMutation("deleteVulnerabilityRequest(requestID: ID!): Boolean!"),

		// Read-only queries; each is served by the Resolver method of the same name.
		schema.AddQuery("vulnerabilityRequest(id: ID!): VulnerabilityRequest"),
		schema.AddQuery("vulnerabilityRequests(query: String, requestIDSelector: String, pagination: Pagination): [VulnerabilityRequest!]!"),
		schema.AddQuery("vulnerabilityRequestsCount(query: String): Int!"),
	)
}

// DeferVulnerability is the resolver for the deferVulnerability mutation.
//
// The deferral workflow is not implemented by this resolver. The context and request are
// ignored, and an error is returned. The schema declares the mutation's result as non-null
// (VulnerabilityRequest!), so a nil resolver paired with a nil error is not a valid response.
// An explicit error gives clients an actionable message instead.
func (resolver *Resolver) DeferVulnerability(
	_ context.Context,
	_ struct{ Request inputtypes.DeferVulnRequest },
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "DeferVulnerability")

	return nil, errors.New("deferVulnerability is not supported")
}

// MarkVulnerabilityFalsePositive is the resolver for the markVulnerabilityFalsePositive mutation.
//
// The false-positive workflow is not implemented by this resolver. The context and request are
// ignored, and an error is returned. The schema declares the mutation's result as non-null
// (VulnerabilityRequest!), so a nil resolver paired with a nil error is not a valid response.
func (resolver *Resolver) MarkVulnerabilityFalsePositive(
	_ context.Context, _ struct {
		Request inputtypes.FalsePositiveVulnRequest
	},
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "MarkVulnerabilityFalsePositive")

	return nil, errors.New("markVulnerabilityFalsePositive is not supported")
}

// ApproveVulnerabilityRequest is the resolver for the approveVulnerabilityRequest mutation.
//
// Approval is not implemented by this resolver. The context, request ID and comment are
// ignored, and an error is returned. The schema declares the mutation's result as non-null
// (VulnerabilityRequest!), so a nil resolver paired with a nil error is not a valid response.
func (resolver *Resolver) ApproveVulnerabilityRequest(
	_ context.Context,
	_ struct {
		RequestID graphql.ID
		Comment   string
	},
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "ApproveVulnerabilityRequest")

	return nil, errors.New("approveVulnerabilityRequest is not supported")
}

// DenyVulnerabilityRequest is the resolver for the denyVulnerabilityRequest mutation.
//
// Denial is not implemented by this resolver. The context, request ID and comment are ignored,
// and an error is returned. The schema declares the mutation's result as non-null
// (VulnerabilityRequest!), so a nil resolver paired with a nil error is not a valid response.
func (resolver *Resolver) DenyVulnerabilityRequest(
	_ context.Context,
	_ struct {
		RequestID graphql.ID
		Comment   string
	},
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "DenyVulnerabilityRequest")

	return nil, errors.New("denyVulnerabilityRequest is not supported")
}

// UpdateVulnerabilityRequest is the resolver for the updateVulnerabilityRequest mutation, which
// is intended to change the expiry of a deferral request.
//
// Updating is not implemented by this resolver. The context, request ID, comment and expiry are
// ignored, and an error is returned. The schema declares the mutation's result as non-null
// (VulnerabilityRequest!), so a nil resolver paired with a nil error is not a valid response.
func (resolver *Resolver) UpdateVulnerabilityRequest(
	_ context.Context,
	_ struct {
		RequestID graphql.ID
		Comment   string
		Expiry    inputtypes.VulnReqExpiry
	},
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "UpdateVulnerabilityRequest")

	return nil, errors.New("updateVulnerabilityRequest is not supported")
}

// UndoVulnerabilityRequest retires (undoes) the vulnerability request identified by
// args.RequestID. The request record itself is kept; nothing is deleted.
//
// The work runs inside processWithAuditLog under the operation name "UndoVulnerabilityRequest".
// Inside that callback the caller must first pass the writeVulnerabilityRequestsOrApprovals
// check. Only then is vulnReqMgr.Undo invoked, and its result is wrapped in a resolver.
func (resolver *Resolver) UndoVulnerabilityRequest(
	ctx context.Context,
	args struct{ RequestID graphql.ID },
) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "UndoVulnerabilityRequest")

	target := &v1.ResourceByID{Id: string(args.RequestID)}
	undo := func() (interface{}, error) {
		if authErr := writeVulnerabilityRequestsOrApprovals(ctx); authErr != nil {
			return nil, authErr
		}
		undone, undoErr := resolver.vulnReqMgr.Undo(ctx, target.GetId(), nil)
		if undoErr != nil {
			return nil, undoErr
		}
		return resolver.wrapVulnerabilityRequest(undone, nil)
	}

	result, err := resolver.processWithAuditLog(ctx, target, "UndoVulnerabilityRequest", undo)
	if result == nil {
		return nil, err
	}
	return result.(*VulnerabilityRequestResolver), err
}

// DeleteVulnerabilityRequest removes the vulnerability request identified by args.RequestID and
// reports true on success.
//
// The deletion runs inside processWithAuditLog under the operation name
// "DeleteVulnerabilityRequest". The callback first enforces writeVulnerabilityRequests and then
// calls vulnReqMgr.Delete. Any failure yields false together with the error.
func (resolver *Resolver) DeleteVulnerabilityRequest(
	ctx context.Context,
	args struct{ RequestID graphql.ID },
) (bool, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "DeleteVulnerabilityRequest")

	requestID := string(args.RequestID)
	target := &v1.ResourceByID{Id: requestID}
	remove := func() (interface{}, error) {
		if authErr := writeVulnerabilityRequests(ctx); authErr != nil {
			return false, authErr
		}
		if deleteErr := resolver.vulnReqMgr.Delete(ctx, requestID); deleteErr != nil {
			return false, deleteErr
		}
		return true, nil
	}

	result, err := resolver.processWithAuditLog(ctx, target, "DeleteVulnerabilityRequest", remove)
	if result == nil {
		return false, err
	}
	return result.(bool), err
}

// VulnerabilityRequest looks up a single vulnerability request by ID. The caller must pass the
// readVulnerabilityRequestsOrApprovals check. A store error is returned as-is. A missing request
// yields (nil, nil), which the nullable schema field renders as null.
func (resolver *Resolver) VulnerabilityRequest(ctx context.Context, args struct{ graphql.ID }) (*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "VulnerabilityRequest")

	if err := readVulnerabilityRequestsOrApprovals(ctx); err != nil {
		return nil, err
	}

	request, found, err := resolver.vulnReqStore.Get(ctx, string(args.ID))
	switch {
	case err != nil:
		return nil, err
	case !found:
		return nil, nil
	}
	return resolver.wrapVulnerabilityRequest(request, nil)
}

// VulnerabilityRequests lists the vulnerability requests matching args.Query.
//
// A missing query is treated as the empty string and parsed with search.MatchAllIfEmpty.
// A non-empty RequestIDSelector is a comma-separated list of request IDs; it is ANDed with the
// parsed query as a document-ID filter. args.Pagination is applied via paginated.FillPagination,
// with math.MaxInt32 as the maximum. The caller must pass readVulnerabilityRequestsOrApprovals.
func (resolver *Resolver) VulnerabilityRequests(ctx context.Context,
	args struct {
		Query             *string
		RequestIDSelector *string
		Pagination        *inputtypes.Pagination
	}) ([]*VulnerabilityRequestResolver, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "VulnerabilityRequests")

	if err := readVulnerabilityRequestsOrApprovals(ctx); err != nil {
		return nil, err
	}

	rawQuery := ""
	if args.Query != nil {
		rawQuery = *args.Query
	}
	parsedQuery, err := search.ParseQuery(rawQuery, search.MatchAllIfEmpty())
	if err != nil {
		return nil, err
	}

	if selector := args.RequestIDSelector; selector != nil && *selector != "" {
		idFilter := search.NewQueryBuilder().AddDocIDs(strings.Split(*selector, ",")...).ProtoQuery()
		parsedQuery = search.ConjunctionQuery(idFilter, parsedQuery)
	}

	paginated.FillPagination(parsedQuery, args.Pagination.AsV1Pagination(), math.MaxInt32)

	requests, err := resolver.vulnReqStore.SearchRawRequests(ctx, parsedQuery)
	if err != nil {
		return nil, err
	}
	return resolver.wrapVulnerabilityRequests(requests, nil)
}

// VulnerabilityRequestsCount counts the vulnerability requests matching args, an empty query
// counting everything the store returns for it. The caller must pass
// readVulnerabilityRequestsOrApprovals. The store count is narrowed to int32 for the GraphQL Int.
func (resolver *Resolver) VulnerabilityRequestsCount(ctx context.Context, args RawQuery) (int32, error) {
	defer metrics.SetGraphQLOperationDurationTime(time.Now(), pkgMetrics.Root, "VulnerabilityRequestsCount")

	if err := readVulnerabilityRequestsOrApprovals(ctx); err != nil {
		return 0, err
	}

	query, err := args.AsV1QueryOrEmpty()
	if err != nil {
		return 0, err
	}

	total, err := resolver.vulnReqStore.Count(ctx, query)
	if err != nil {
		return 0, err
	}
	return int32(total), nil
}

// VulnerabilityRequestResolver serves the GraphQL VulnerabilityRequest type for one stored request.
type VulnerabilityRequestResolver struct {
	// root is the top-level resolver; it provides the wrappers and managers used by derived fields.
	root *Resolver
	// data is the stored request whose fields are exposed.
	data *storage.VulnerabilityRequest
}

// wrapVulnerabilityRequest turns a stored vulnerability request into a resolver. A non-nil err is
// passed through with a nil resolver, and a nil value yields (nil, nil).
func (resolver *Resolver) wrapVulnerabilityRequest(value *storage.VulnerabilityRequest, err error) (*VulnerabilityRequestResolver, error) {
	if err == nil && value != nil {
		return &VulnerabilityRequestResolver{root: resolver, data: value}, nil
	}
	return nil, err
}

// wrapVulnerabilityRequests turns stored vulnerability requests into resolvers, preserving order.
// A non-nil err is passed through, and an empty input yields a nil slice.
func (resolver *Resolver) wrapVulnerabilityRequests(values []*storage.VulnerabilityRequest, err error) ([]*VulnerabilityRequestResolver, error) {
	if err != nil || len(values) == 0 {
		return nil, err
	}
	wrapped := make([]*VulnerabilityRequestResolver, len(values))
	for i, request := range values {
		wrapped[i] = &VulnerabilityRequestResolver{root: resolver, data: request}
	}
	return wrapped, nil
}

// ID exposes the stored request's identifier as a GraphQL ID.
func (vr *VulnerabilityRequestResolver) ID(_ context.Context) graphql.ID {
	id := vr.data.GetId()
	return graphql.ID(id)
}

// TargetState reports, by enum name, the state the request asks for the vulnerability.
func (vr *VulnerabilityRequestResolver) TargetState(_ context.Context) string {
	state := vr.data.GetTargetState()
	return state.String()
}

// Status reports the request's current status by enum name.
func (vr *VulnerabilityRequestResolver) Status(_ context.Context) string {
	status := vr.data.GetStatus()
	return status.String()
}

// Expired reports the stored expired flag of the request.
func (vr *VulnerabilityRequestResolver) Expired(_ context.Context) bool {
	expired := vr.data.GetExpired()
	return expired
}

// Requestor returns the user who submitted the vulnerability request.
func (vr *VulnerabilityRequestResolver) Requestor(_ context.Context) (*slimUserResolver, error) {
	requestor := vr.data.GetRequestor()
	return vr.root.wrapSlimUser(requestor, true, nil)
}

// Approvers returns the users who approved the vulnerability request, if any.
func (vr *VulnerabilityRequestResolver) Approvers(_ context.Context) ([]*slimUserResolver, error) {
	approvers := vr.data.GetApprovers()
	return vr.root.wrapSlimUsers(approvers, nil)
}

// CreatedAt returns the request's creation timestamp converted to GraphQL time.
func (vr *VulnerabilityRequestResolver) CreatedAt(_ context.Context) (*graphql.Time, error) {
	created := vr.data.GetCreatedAt()
	return protocompat.ConvertTimestampToGraphqlTimeOrError(created)
}

// LastUpdated returns the request's last-update timestamp converted to GraphQL time.
func (vr *VulnerabilityRequestResolver) LastUpdated(_ context.Context) (*graphql.Time, error) {
	updated := vr.data.GetLastUpdated()
	return protocompat.ConvertTimestampToGraphqlTimeOrError(updated)
}

// Comments returns the comments recorded on the request.
func (vr *VulnerabilityRequestResolver) Comments(_ context.Context) ([]*requestCommentResolver, error) {
	comments := vr.data.GetComments()
	return vr.root.wrapRequestComments(comments, nil)
}

// Scope returns the scope the request applies to.
func (vr *VulnerabilityRequestResolver) Scope(_ context.Context) (*vulnerabilityRequest_ScopeResolver, error) {
	scope := vr.data.GetScope()
	return vr.root.wrapVulnerabilityRequest_Scope(scope, true, nil)
}

// DeferralReq returns the deferral details of the request; the wrapper yields nil when absent.
func (vr *VulnerabilityRequestResolver) DeferralReq(_ context.Context) (*DeferralRequestResolver, error) {
	deferral := vr.data.GetDeferralReq()
	return vr.root.wrapDeferralRequest(deferral, nil)
}

// FalsePositiveReq returns the false-positive details of the request.
func (vr *VulnerabilityRequestResolver) FalsePositiveReq(_ context.Context) (*falsePositiveRequestResolver, error) {
	falsePositive := vr.data.GetFpRequest()
	return vr.root.wrapFalsePositiveRequest(falsePositive, true, nil)
}

// UpdatedDeferralReq returns the updated deferral details of the request; nil when absent.
func (vr *VulnerabilityRequestResolver) UpdatedDeferralReq(_ context.Context) (*DeferralRequestResolver, error) {
	updated := vr.data.GetUpdatedDeferralReq()
	return vr.root.wrapDeferralRequest(updated, nil)
}

// Cves returns the CVEs covered by the request.
func (vr *VulnerabilityRequestResolver) Cves(_ context.Context) (*vulnerabilityRequest_CVEsResolver, error) {
	cves := vr.data.GetCves()
	return vr.root.wrapVulnerabilityRequest_CVEs(cves, true, nil)
}

// DeploymentCount counts the deployments affected by this request that also match args. The
// caller must pass readDeployments; the manager's count is narrowed to int32 for GraphQL.
func (vr *VulnerabilityRequestResolver) DeploymentCount(ctx context.Context, args RawQuery) (int32, error) {
	if err := readDeployments(ctx); err != nil {
		return 0, err
	}

	filter, err := args.AsV1QueryOrEmpty()
	if err != nil {
		return 0, err
	}

	n, err := vr.root.vulnReqQueryMgr.DeploymentCount(ctx, vr.data.GetId(), filter)
	if err != nil {
		return 0, err
	}
	return int32(n), nil
}

// ImageCount counts the images affected by this request that also match args. The caller must
// pass readImages; the manager's count is narrowed to int32 for GraphQL.
func (vr *VulnerabilityRequestResolver) ImageCount(ctx context.Context, args RawQuery) (int32, error) {
	if err := readImages(ctx); err != nil {
		return 0, err
	}

	filter, err := args.AsV1QueryOrEmpty()
	if err != nil {
		return 0, err
	}

	n, err := vr.root.vulnReqQueryMgr.ImageCount(ctx, vr.data.GetId(), filter)
	if err != nil {
		return 0, err
	}
	return int32(n), nil
}

// Deployments lists the deployments affected by this request that also match args (including its
// pagination). The caller must pass readDeployments.
func (vr *VulnerabilityRequestResolver) Deployments(ctx context.Context, args PaginatedQuery) ([]*deploymentResolver, error) {
	if err := readDeployments(ctx); err != nil {
		return nil, err
	}

	filter, err := args.AsV1QueryOrEmpty()
	if err != nil {
		return nil, err
	}

	deployments, err := vr.root.vulnReqQueryMgr.Deployments(ctx, vr.data.GetId(), filter)
	return vr.root.wrapDeployments(deployments, err)
}

// Images lists the images affected by this request that also match args (including its
// pagination). The caller must pass readImages.
//
// With the FlattenImageData feature enabled the V2 image model is queried; otherwise the legacy
// model is used. Either way the concrete resolvers are returned through the ImageResolver
// interface, together with any error from the wrapper.
func (vr *VulnerabilityRequestResolver) Images(ctx context.Context, args PaginatedQuery) ([]ImageResolver, error) {
	if err := readImages(ctx); err != nil {
		return nil, err
	}

	filter, err := args.AsV1QueryOrEmpty()
	if err != nil {
		return nil, err
	}

	if !features.FlattenImageData.Enabled() {
		legacy, wrapErr := vr.root.wrapImages(vr.root.vulnReqQueryMgr.Images(ctx, vr.data.GetId(), filter))
		out := make([]ImageResolver, len(legacy))
		for i := range legacy {
			out[i] = legacy[i]
		}
		return out, wrapErr
	}

	flattened, wrapErr := vr.root.wrapImageV2s(vr.root.vulnReqQueryMgr.ImageV2s(ctx, vr.data.GetId(), filter))
	out := make([]ImageResolver, len(flattened))
	for i := range flattened {
		out[i] = flattened[i]
	}
	return out, wrapErr
}

// DeferralRequestResolver serves the GraphQL DeferralRequest type for one stored deferral request.
type DeferralRequestResolver struct {
	// root is the top-level resolver this one was created from.
	root *Resolver
	// data is the stored deferral request whose fields are exposed.
	data *storage.DeferralRequest
}

// wrapDeferralRequest turns a stored deferral request into a resolver. A non-nil err is passed
// through with a nil resolver, and a nil value yields (nil, nil).
func (resolver *Resolver) wrapDeferralRequest(value *storage.DeferralRequest, err error) (*DeferralRequestResolver, error) {
	switch {
	case err != nil, value == nil:
		return nil, err
	default:
		return &DeferralRequestResolver{root: resolver, data: value}, nil
	}
}

// ExpiresOn returns the deferral's time-based expiry, converted to GraphQL time, for requests
// that were given one.
func (dr *DeferralRequestResolver) ExpiresOn(_ context.Context) (*graphql.Time, error) {
	expiry := dr.data.GetExpiry()
	return protocompat.ConvertTimestampToGraphqlTimeOrError(expiry.GetExpiresOn())
}

// ExpiresWhenFixed reports whether the deferral is set to expire once the vulnerability becomes
// fixable.
func (dr *DeferralRequestResolver) ExpiresWhenFixed(_ context.Context) bool {
	expiry := dr.data.GetExpiry()
	return expiry.GetExpiresWhenFixed()
}

// exceptionQueryFilters carries the optional filters for unExpiredExceptionQuery; an empty slice
// adds no constraint.
type exceptionQueryFilters struct {
	// cves restricts results to exceptions matching any of these CVEs.
	cves []string

	// requestStates restricts results to exceptions whose request status is any of these values.
	requestStates []string
}

// unExpiredExceptionQuery returns a query to retrieve unexpired vulnerability exceptions. If an image scope is found
// in `ctx`, then the query for exceptions covering the specified image is returned.
// Filters `cves` and `requestStates` are optional; an empty slice adds no constraint.
//
// With an image in scope, the result additionally requires the exception's scope to cover that image. It may be
// global (all registries, remotes and tags), all tags of the image's registry/remote, or the exact registry/remote/tag.
// An error is returned when the scope holds anything other than exactly one image ID, or when that image cannot be
// loaded.
func unExpiredExceptionQuery(ctx context.Context, filters exceptionQueryFilters) (*v1.Query, error) {
	qb := search.NewQueryBuilder().AddBools(search.ExpiredRequest, false)

	if len(filters.cves) > 0 {
		qb.AddExactMatches(search.CVE, filters.cves...)
	}
	if len(filters.requestStates) > 0 {
		qb.AddExactMatches(search.RequestStatus, filters.requestStates...)
	}

	imageID, err := vulnExceptionScopedImageID(ctx)
	if err != nil {
		return nil, err
	}
	// Without an image in scope only the base filters apply.
	if imageID == "" {
		return qb.ProtoQuery(), nil
	}

	registry, remote, tag, err := vulnExceptionImageNameParts(ctx, imageID)
	if err != nil {
		return nil, err
	}
	return search.ConjunctionQuery(qb.ProtoQuery(), vulnExceptionScopeCoveringImage(registry, remote, tag)), nil
}

// vulnExceptionScopedImageID returns the single image ID scoped in ctx, or "" when ctx has no image scope.
// The scope is read at the IMAGES_V2 level when FlattenImageData is enabled, and at the IMAGES level otherwise.
func vulnExceptionScopedImageID(ctx context.Context) (string, error) {
	searchCategory := v1.SearchCategory_IMAGES
	if features.FlattenImageData.Enabled() {
		searchCategory = v1.SearchCategory_IMAGES_V2
	}
	scope, hasScope := scoped.GetScopeAtLevel(ctx, searchCategory)
	if !hasScope {
		return "", nil
	}
	if len(scope.IDs) != 1 {
		return "", errors.New("invalid request, only a single image is allowed")
	}
	return scope.IDs[0], nil
}

// vulnExceptionImageNameParts loads the image with the given ID through the request's image loader (the V2 loader
// when FlattenImageData is enabled). It returns the image's registry, remote and tag, in that order.
func vulnExceptionImageNameParts(ctx context.Context, imageID string) (string, string, string, error) {
	// NOTE(review): the lookup runs under a fresh all-access context built from context.Background(), not ctx, so
	// ctx's cancellation and deadline do not apply to it — confirm this detachment is intended.
	if features.FlattenImageData.Enabled() {
		imageLoader, err := loaders.GetImageV2Loader(ctx)
		if err != nil {
			return "", "", "", errors.Wrap(err, "getting image loader")
		}
		img, err := imageLoader.FromID(sac.WithAllAccess(context.Background()), imageID)
		if err != nil {
			return "", "", "", errors.Wrapf(err, "fetching image with id %s", imageID)
		}
		name := img.GetName()
		return name.GetRegistry(), name.GetRemote(), name.GetTag(), nil
	}

	imageLoader, err := loaders.GetImageLoader(ctx)
	if err != nil {
		return "", "", "", errors.Wrap(err, "getting image loader")
	}
	img, err := imageLoader.FromID(sac.WithAllAccess(context.Background()), imageID)
	if err != nil {
		return "", "", "", errors.Wrapf(err, "fetching image with id %s", imageID)
	}
	name := img.GetName()
	return name.GetRegistry(), name.GetRemote(), name.GetTag(), nil
}

// vulnExceptionScopeCoveringImage returns a disjunction matching exception scopes that cover an image with the
// given registry, remote and tag. It matches the global scope, the all-tags scope for the image's registry/remote,
// and the exact registry/remote/tag scope.
func vulnExceptionScopeCoveringImage(registry, remote, tag string) *v1.Query {
	return search.DisjunctionQuery(
		search.NewQueryBuilder().
			AddExactMatches(search.ImageRegistryScope, common.MatchAll).
			AddExactMatches(search.ImageRemoteScope, common.MatchAll).
			AddExactMatches(search.ImageTagScope, common.MatchAll).ProtoQuery(),
		search.NewQueryBuilder().
			AddExactMatches(search.ImageRegistryScope, registry).
			AddExactMatches(search.ImageRemoteScope, remote).
			AddExactMatches(search.ImageTagScope, common.MatchAll).ProtoQuery(),
		// An untagged image has tag "", which is matched exactly as an empty tag.
		search.NewQueryBuilder().
			AddExactMatches(search.ImageRegistryScope, registry).
			AddExactMatches(search.ImageRemoteScope, remote).
			AddExactMatches(search.ImageTagScope, tag).ProtoQuery(),
	)
}
